
Change tensorflow model format to SavedModel to support sub-classed models#628

Closed
ascillitoe wants to merge 26 commits into SeldonIO:master from ascillitoe:feature/tf_SavedModel

Conversation

@ascillitoe (Contributor) commented Sep 21, 2022

This PR changes the format we use to serialize TensorFlow models from the old HDF5 to the newer SavedModel format.

Motivation

As well as being the default (and recommended) TensorFlow model format, the SavedModel format has the advantage of supporting serialisation of sub-classed tensorflow models. That is, models constructed by subclassing tf.keras.Model, rather than by using tf.keras.Sequential, tf.keras.models.Model etc.
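For reference, a minimal sketch of the kind of subclassed model this enables (the MyModel name and architecture are illustrative, not from this PR):

```python
import tensorflow as tf

class MyModel(tf.keras.Model):
    """A subclassed model; not serialisable with the old HDF5 format."""
    def __init__(self, n_hidden: int = 32, n_out: int = 2) -> None:
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(n_hidden, activation='relu')
        self.dense2 = tf.keras.layers.Dense(n_out)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        return self.dense2(self.dense1(x))

model = MyModel()
model(tf.zeros((1, 4)))  # subclassed models must be built (called) before saving
model.save('my_model')   # writes the SavedModel format; saving to 'my_model.h5' would fail
```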

Limitations

  • The new format slightly improves the handling of custom tensorflow objects (such as layers and models). If these are not passed to tf.keras.models.load_model via custom_objects (or registered with @tf.keras.utils.register_keras_serializable()), the model will be loaded as a keras.saving.saved_model.load.<model_class>*. This is a rough copy of the original model that behaves identically with respect to inference, but cannot be cloned (cloning is done in a number of learned detectors, such as ClassifierDrift). To load the fully-functional model, all custom objects must be supplied at load time (see the sketch below).
  • The HiddenOutput class does not work for subclassed models. Therefore, subclassed models cannot be saved/loaded when layer is specified in the ModelConfig.

*in tensorflow>=2.9. In older versions, loading of the model will fail completely if the custom objects are not provided.
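To illustrate, a minimal sketch of the two ways to supply a custom object at load time (the MyEncoder class and 'encoder' path are hypothetical):

```python
import tensorflow as tf

# Option 1: register the custom class so tensorflow can revive it by name.
@tf.keras.utils.register_keras_serializable(package='my_package')
class MyEncoder(tf.keras.Model):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(8)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        return self.dense(x)

# Option 2: pass the class explicitly when loading.
model = tf.keras.models.load_model('encoder', custom_objects={'MyEncoder': MyEncoder})
```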

Main changes

  • Changed save_format from 'h5' to 'tf' in save_model and load_model, although we stick with h5 for the legacy save/load functions.
  • Removed support for passing a custom_objects dictionary via config, since support for this was very flaky. Custom objects in the dictionary could only realistically be specified as registered object strings ('@mymodel' etc.). However, this is confusing as tensorflow already has its own @tf.keras.utils.register_keras_serializable() decorator.
  • load_detector now allows arbitrary kwargs, which are passed to tf.keras.models.load_model (or torch.load). This is to be used to provide the custom_objects at load time (see example below).
  • Added subclassed models to CI.

Example

Example notebook demonstrating serialisation of a detector with a subclassed tensorflow model. Observe how the custom objects must be passed to load_detector in order to avoid the error KeyError: 'layers'.
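In outline, the pattern from the notebook looks something like this (a rough sketch; the 'my_detector' path, cd detector variable and MyEncoder class are illustrative):

```python
from alibi_detect.saving import save_detector, load_detector

save_detector(cd, 'my_detector')  # the tensorflow model is now written in SavedModel format

# Without the custom objects the model is only partially revived
# (or loading fails outright, e.g. with KeyError: 'layers').
cd = load_detector('my_detector', custom_objects={'MyEncoder': MyEncoder})
```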

Backwards compatibility

tf.keras.models.load_model automatically detects whether a given model path represents a h5 model or a SavedModel. This means we should be backwards compatible: we can simply move to saving SavedModels, while still supporting loading of legacy h5 models.
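For example (paths are illustrative), both of the following work with the same call:

```python
import tensorflow as tf

legacy_model = tf.keras.models.load_model('detector/model/model.h5')  # legacy HDF5 file
new_model = tf.keras.models.load_model('detector/model')              # SavedModel directory
```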

TODOs

  • Better document limitations of SavedModel format in docs.
    - [ ] Add a more involved example of passing custom objects to load_detector. More challenging than first envisaged; I wanted to demonstrate on the amazon example, where a subclassed ClassifierTF model is used, but saving is not supported there due to the use of tokenize_transformer. This is not related to subclassed models, so I would like to leave it for a follow-up PR.
  • CHANGELOG.md - see Change tensorflow model format to SavedModel to support sub-classed models #628 (review)

Old notes etc

This notebook contains some experiments run to explore limitations of the SavedModel format.

@ascillitoe added the Type: Serialization (Serialization proposals and changes) and WIP (PR is a Work in Progress) labels on Sep 21, 2022
```diff
 try:  # legacy load_model behaviour was to return None if not found. Now it raises an error, hence the need for try-except.
     model = load_model(filepath, load_dir='encoder')
-except FileNotFoundError:
+except OSError:
```
Contributor Author

Changed to OSError because we are now relying on tf.keras.models.load_model to raise an error when loading fails, and this is what it raises...

```diff
 elif isinstance(detector, (ChiSquareDrift, ClassifierDrift, KSDrift, MMDDrift, TabularDrift)):
     if model is not None:
-        save_model(model, filepath, save_dir='encoder')
+        save_model(model, filepath, save_dir='encoder', save_format='h5')
```
Contributor Author

Stick to saving in h5 format for legacy saves...

@codecov-commenter commented Sep 22, 2022

Codecov Report

Merging #628 (1dc7b61) into master (c0c5e64) will decrease coverage by 0.03%.
The diff coverage is 100.00%.

Additional details and impacted files


@@            Coverage Diff             @@
##           master     #628      +/-   ##
==========================================
- Coverage   80.35%   80.33%   -0.03%     
==========================================
  Files         137      137              
  Lines        9300     9304       +4     
==========================================
+ Hits         7473     7474       +1     
- Misses       1827     1830       +3     
| Flag | Coverage Δ |
| --- | --- |
| macos-latest-3.9 | 76.81% <100.00%> (+0.01%) ⬆️ |
| ubuntu-latest-3.10 | 80.22% <100.00%> (+<0.01%) ⬆️ |
| ubuntu-latest-3.7 | 80.12% <100.00%> (+0.01%) ⬆️ |
| ubuntu-latest-3.8 | 80.17% <100.00%> (+0.01%) ⬆️ |
| ubuntu-latest-3.9 | ? |
| windows-latest-3.9 | 76.81% <100.00%> (-0.02%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
| --- | --- |
| alibi_detect/saving/schemas.py | 98.77% <ø> (-0.01%) ⬇️ |
| alibi_detect/base.py | 85.45% <100.00%> (ø) |
| alibi_detect/saving/_pytorch/loading.py | 92.15% <100.00%> (+0.32%) ⬆️ |
| alibi_detect/saving/_tensorflow/loading.py | 85.44% <100.00%> (+0.23%) ⬆️ |
| alibi_detect/saving/_tensorflow/saving.py | 81.81% <100.00%> (-0.06%) ⬇️ |
| alibi_detect/saving/loading.py | 93.83% <100.00%> (-0.03%) ⬇️ |
| alibi_detect/datasets.py | 68.69% <0.00%> (-1.31%) ⬇️ |

@ascillitoe changed the title from "Change tensorflow model format from hdf5 to SavedModel" to "Change tensorflow model format to SavedModel to support sub-classed models" on Jan 3, 2023
@ascillitoe added this to the v0.11.0 milestone on Jan 3, 2023
@ascillitoe removed the WIP (PR is a Work in Progress) label on Jan 18, 2023
@ascillitoe requested review from jklaise and mauicv on January 19, 2023 09:52
@mauicv (Contributor) left a comment

LGTM!


@ascillitoe (Contributor Author) commented Jan 27, 2023

The commits from f61af31 onwards contain three primary changes (based on discussion with @jklaise ):

  1. Subclassed tf models must have been built/called before they can be saved. It was decided that attempting to perform a dummy call prior to saving was risky. Instead, the ValueError that occurs in alibi_detect.saving._tensorflow.save_model is caught and re-raised with a more informative error message (see the sketch after this list).
  2. For custom layers or subclassed models, various errors can occur during inference or cloning if custom objects are not properly provided at load time. The errors are often unclear and surface from a variety of places, e.g. in detector predict methods. alibi_detect.saving._tensorflow.load_model now runs some basic checks on the loaded model and raises a warning if any problems are detected. This allows problems to be discovered when the detector is first loaded, instead of having to wait until prediction time.
  3. A more prominent warning about providing custom objects at load time is added to the docs.
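For point 1, the catch-and-re-raise looks roughly like this (a sketch, not the exact PR code; `model` and `filepath` are assumed to be in scope inside save_model):

```python
try:
    model.save(filepath, save_format='tf')
except ValueError as error:
    # Unbuilt subclassed models have no defined input shapes, so saving fails.
    raise ValueError(
        'Saving the model failed. If it is a subclassed model, it must be built '
        'before it can be saved, e.g. by calling it on a batch of data: `model(x)`.'
    ) from error
```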

model that can be saved and loaded with `torch.save(..., pickle_module=dill)` and `torch.load(..., pickle_module=dill)`.

```{note}
- The {obj}`~alibi_detect.cd.tensorflow.HiddenOutput` utility class is not currently compatible with subclassed models.
```
Contributor

In what sense is it not? How could it be made compatible in the future?

@ascillitoe (Contributor Author) commented Jan 27, 2023

HiddenOutput works by building a new tf.keras.Model from the original model's input and layers attributes.

```python
import numpy as np
import tensorflow as tf
from typing import Union
from tensorflow.keras.layers import Flatten, Input
from tensorflow.keras.models import Model


class HiddenOutput(tf.keras.Model):
    def __init__(
            self,
            model: tf.keras.Model,
            layer: int = -1,
            input_shape: tuple = None,
            flatten: bool = False
    ) -> None:
        super().__init__()
        if input_shape and not model.inputs:
            inputs = Input(shape=input_shape)
            model.call(inputs)
        else:
            inputs = model.inputs
        # Build a new functional model from the original model's symbolic
        # inputs, truncated at the requested layer's output.
        self.model = Model(inputs=inputs, outputs=model.layers[layer].output)
        self.flatten = Flatten() if flatten else tf.identity

    def call(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
        return self.flatten(self.model(x))
```

This works fine for tf.keras.Models constructed with tf.keras.Sequential etc., as they have a pre-defined structure for these attributes. It doesn't work out-of-the-box for subclassed models, presumably because there are quite a few different ways these models can be constructed internally. Probably worth opening an issue to explore this one further...
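A quick illustration of the difference (illustrative only, assuming a subclassed model like the MyModel sketch in the PR description):

```python
model = MyModel()
model(tf.zeros((1, 4)))   # builds the weights, but records no symbolic graph
print(model.inputs)       # None/[] - no symbolic inputs to rebuild a Model from
model.layers[-1].output   # raises AttributeError - the layer has no defined output
```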

Contributor

I see, feels like a docstring is required for HiddenOutput (and other utility classes) to describe behaviour/limitations. For another PR though.

@ascillitoe (Contributor Author) commented Jan 27, 2023

I'll open an issue shortly 👍🏻

Edit: #734

@@ -18,6 +18,7 @@

def load_model(filepath: Union[str, os.PathLike],
Contributor

Extra functionality sneaking into this PR... Worth adding changelog entries to this PR so everything is documented and not missed upon release?

@ascillitoe (Contributor Author) commented Jan 27, 2023

Not 100% clear what you mean here? The functionality of passing kwargs to load_detector? It is extra functionality but is interlinked with the PR, as custom_objects needs to be passed to load_detector.

Edit: reading again, I see what you mean, since we also pass kwargs to pytorch. I could factor this out into a separate PR if preferred...

Contributor

No need, but would appreciate a changelog as part of the PR.

```python
from alibi_detect.models.tensorflow.autoencoder import (AE, AEGMM, VAE, VAEGMM,
                                                        DecoderLSTM,
                                                        EncoderLSTM, Seq2Seq)
from alibi_detect.utils.tensorflow.misc import check_model
```
Contributor

Minor, but why not keep the function in this module since it's specifically used during loading?

Contributor Author

Mmn, I put it here so that it's next to clone_model, which is quite related. But I can move it back to saving since we only use it there...

Comment on lines 128 to 135
```python
        # Check model cloning doesn't raise error
        clone_model(model)

    except Exception as error:
        if raise_error:
            raise ValueError(msg) from error
        else:
            warnings.warn(msg + f"Original error message: \n\t{error}")
```
Contributor

Not sure this works that well together, e.g. we check if cloning works but in any case the error is never raised. So effectively this would blow up again when used inside detectors that do use cloning? Or is the intention to call this method there also?

Contributor Author

The problem I am having here is that there are a large number of possible failure modes that can occur due to not specifying custom_objects properly. Depending on the exact tf version, and on whether the model itself is subclassed or just its layers, either ValueError, TypeError or NotImplementedError is raised, and this happens either during inference or cloning.

The widest net I've been able to cast is to simply test whether cloning works, and, to catch potential problems at inference, to check whether the model is a RevivedNetwork, which is what it is loaded as if custom objects are missing (I don't try to actually call the model, as data isn't available at this point).

A wide net is nice in terms of hopefully catching most errors, but has the downside of throwing errors when things might actually have worked. For example, RevivedNetworks generally work for inference but not cloning, so a user can actually get away with not passing custom_objects if the model is just to be used for preprocessing. Hence why I went with a warning instead of an error...

Maybe one compromise is to raise a warning if the model is a RevivedNetwork, since this can cause issues, but raise the ValueError (with the more coherent message) if cloning fails?

Contributor

I'm not sure, cloning failure is only relevant for a specific set of detectors, so it somehow feels like it should be checked only for that subset. Maybe leave this as is and wrap the relevant detector calls to clone() in a try/except with a customized Alibi error message?
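i.e. something along these lines (a sketch of the suggestion with an illustrative message, `model` assumed in scope; not actual detector code):

```python
from tensorflow.keras.models import clone_model

try:
    model_clone = clone_model(model)
except Exception as error:
    raise ValueError(
        'Cloning of the model failed. If the model contains custom objects, these '
        'must be provided at load time via the `custom_objects` kwarg.'
    ) from error
```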

Contributor Author

Sorry, just coming back on this, I've thought about it a bit more. There are just so many different failure modes, and the problem with just checking whether the model is cloneable is that there is one nasty failure mode where inference works, cloning works, BUT inference after cloning doesn't work 🤯 (this occurs in example 5 here; the custom call method is lost when cloning...).

This is what motivates casting the wide net by also checking if the model is a RevivedNetwork. But then throwing a warning for this would prevent users getting away with cases that would otherwise work with no custom_objects provided (generally when no cloning is involved).

I see two options:

  1. Leave pretty much as is. The downside at the moment is that, by swallowing the ValueError from cloning and raising it as a warning, it can be somewhat lost among the actual errors that occur later (i.e. when cloning is done for real).
  2. Turn the warning into an error. Be up front about requiring custom_objects to be passed in all cases where custom objects are involved. This is slightly less convenient for the user, but means we don't have to worry about the many different failure modes, which might also change in future versions.

Contributor

Raising an error here would pretty much be equivalent to outright not supporting revived models? Perhaps it's not a bad idea, especially since that functionality is not well documented in the tensorflow docs.

Contributor Author

Yes exactly. It would mean any model containing custom classes that is loaded without providing those custom classes (via registering, or the custom_objects kwarg) would be pretty much guaranteed to fail. Thinking about it more, this is probably what I would lean towards. Although to be clear, this does also mean that, for example, if a user had a relatively simple tf.keras.Sequential with one custom layer, they would have to provide this layer at load time even if they don't use the model in a detector that does any cloning...

Contributor Author

The tensorflow behaviour here has been so fluid across versions 2.9, 2.10 and 2.11 that I do think it'd be the safest option though... (for now)

```diff
-        raise FileNotFoundError(f'{model_name} not found in {model_dir.resolve()}.')
-    model = tf.keras.models.load_model(model_dir.joinpath(model_name), custom_objects=custom_objects)
+    # Load model
+    model = tf.keras.models.load_model(filepath, **kwargs)
```
Contributor

Seems like we're throwing away the validation code for the existence of the model? Or is it done from higher up in another caller?

Contributor Author

It's not actually done from higher up. Rather, I realised that the validation might be superfluous, since tf.keras.models.load_model already raises OSError: No file or directory found at test.h5 if a filepath to a .h5 model is passed and one doesn't exist, and OSError: SavedModel file does not exist at: test//{saved_model.pbtxt|saved_model.pb} if a directory is passed.

Contributor

OK that makes sense.

```python
def load_model(filepath: Union[str, os.PathLike],
               filename: str = 'model',
               custom_objects: dict = None,
               layer: Optional[int] = None,
```
Contributor

Minor, but I don't fully agree with omitting filename from these internal save/load functions. What we save in the function signature, we pay for at every callsite, having to remember to do .joinpath(filename). (OTOH, in the old behaviour, having a default model name is also likely not desirable, as forgetting to set it would result in a perhaps unexpected default.)

@ascillitoe (Contributor Author) commented Jan 27, 2023

Pretty much the only reason I made this change is that the filename is only there for legacy loading (legacy as in saving to .h5, and legacy as in loading files with different names such as encoder.h5). For the "modern" loading we simply do:

    if flavour == Framework.TENSORFLOW:
        model = load_model_tf(src, layer=layer, **kwargs)

So I saw it as a trade-off between carrying around more complexity in load_model to facilitate the legacy functionality in load_detector_legacy, or adding some complexity to the calls in load_detector_legacy to simplify the load_model function... (which just happens to be used for both modern and legacy loading).

Similar story for save_model...

Contributor Author

> Also, noting that there are a few subtle changes in the behaviour of internal functions for saving/loading, we need to be extra vigilant that new bugs haven't been introduced.

p.s. below is very true though... tweaking anything to do with legacy save/load does bring the potential for bugs like the ones v0.10.5 fixed. I've run the same tests of loading old artefacts we ran in #732 and everything passes, but that isn't 100% comprehensive...

@jklaise (Contributor) left a comment

Generally LGTM, but would appreciate a changelog with all the changes. Also, noting that there are a few subtle changes in the behaviour of internal functions for saving/loading, we need to be extra vigilant that new bugs haven't been introduced.

@ascillitoe modified the milestones: v0.11.0 → v0.12.0 on Jan 31, 2023
@ascillitoe (Contributor Author)

Postponing this to v0.12.0 so that we can combine it with deprecation of legacy saving and #723.

@ascillitoe closed this by deleting the head repository on Apr 24, 2024